import numpy as np
import datetime
import torch
from scipy.stats import norm
from scipy.stats import t

from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import QuantileRegressor
import torch.utils.data as data
from torch.utils.data import DataLoader

def pipeline(clf, X_train, X_test, Y_train, Y_test, scores, params):
    '''
    A pipeline responsible for calculating all "scores" necessary for the following CFD based MI attacks:
    * Counterfactual distance attack (score: CFD, i.e. L2 distance of each point
      to the target model's linear decision boundary)
    * Counterfactual distance LRT attack (score: likelihood ratio Lambda)
        * Global variance
        * Local variance
    The target model's logit scores and the corresponding LRT scores on absolute
    logits (global and local variance) are recorded as well.

    Args:
        clf: fitted target model exposing ``predict_proba`` and ``coef_``; when
            ``params['ensemble']`` is true, an indexable collection whose first
            ``params['n_ensemble']`` entries are such models.
        X_train: feature matrix whose results are stored under the ``*_train`` keys.
        X_test: feature matrix whose results are stored under the ``*_test`` keys;
            it is also the pool the shadow models are fitted on.
        Y_train: unused here; kept for interface compatibility.
        Y_test: labels of ``X_test``, used to fit the shadow models.
        scores: dict of lists; one new entry is appended to each of the keys
            'preds_train', 'preds_test', 'preds_lrt_{train,test}_{global,local}',
            'dists_train', 'dists_test', 'dists_lrt_{train,test}_{global,local}'.
        params: configuration dict. Read here: 'epsd', 'ensemble', 'n_ensemble'
            (ensemble only), 'weighting', 'dp_laplace', 'epsilon' (dp_laplace
            only); the shadow-model helpers read further keys.
    Returns: scores, the same dictionary that was passed in.
    '''
    epsd = params['epsd']

    def logit(p1, p0):
        # log-odds; epsd keeps both logs finite when a probability is exactly 0
        return np.log(p1 + epsd) - np.log(p0 + epsd)

    def global_var(shadow_train, shadow_test):
        # 'equal' weighting pins the global variance to 1; otherwise use the mean
        # over all points of each point's variance across the shadow models
        if params['weighting'] == 'equal':
            return 1
        return np.mean(np.var(np.r_[shadow_train, shadow_test], axis=1))

    ########## COMPUTE PREDICTED PROBABILITIES ##########
    if params['ensemble']:
        n_ensemble = params['n_ensemble']
        y_pred_train1 = y_pred_train0 = y_pred_test1 = y_pred_test0 = 0  # initialization
        for i in range(n_ensemble):
            # every ensemble member gets equal weight: average the probabilities
            proba_train = clf[i].predict_proba(X_train)
            proba_test = clf[i].predict_proba(X_test)
            y_pred_train1 += (1 / n_ensemble) * proba_train[:, 1]
            y_pred_train0 += (1 / n_ensemble) * proba_train[:, 0]
            y_pred_test1 += (1 / n_ensemble) * proba_test[:, 1]
            y_pred_test0 += (1 / n_ensemble) * proba_test[:, 0]
    else:
        proba_train = clf.predict_proba(X_train)
        proba_test = clf.predict_proba(X_test)
        y_pred_train1 = proba_train[:, 1]
        y_pred_train0 = proba_train[:, 0]
        y_pred_test1 = proba_test[:, 1]
        y_pred_test0 = proba_test[:, 0]

    ########## COMPUTE PREDICTIONS ##########
    # "vanilla" predictions: logit scores of the target model
    logit_train = logit(y_pred_train1, y_pred_train0)
    logit_test = logit(y_pred_test1, y_pred_test0)
    scores['preds_train'].append(logit_train)
    scores['preds_test'].append(logit_test)

    # LRT predictions: absolute logits compared against shadow-model logits
    shadow_preds_train, shadow_preds_test = compute_shadowpredictions(X_train, X_test,
                                                                      X_test, Y_test,
                                                                      params)
    var_global_preds = global_var(shadow_preds_train, shadow_preds_test)
    abs_logit_train = np.abs(logit_train)
    abs_logit_test = np.abs(logit_test)
    for use_global, key_train, key_test in (
            (True, 'preds_lrt_train_global', 'preds_lrt_test_global'),
            (False, 'preds_lrt_train_local', 'preds_lrt_test_local')):
        scores[key_train].append(compute_lambda_predictions(
            abs_logit_train, shadow_preds_train, var_global_preds,
            global_variance=use_global))
        scores[key_test].append(compute_lambda_predictions(
            abs_logit_test, shadow_preds_test, var_global_preds,
            global_variance=use_global))

    ########## COMPUTE DISTANCES ##########
    # closed-form L2 distance to the boundary, valid because the model is linear
    if params['dp_laplace']:
        # NOTE(review): the Laplace noise only affects the distance scores below;
        # the prediction scores above use the noiseless probabilities.
        ones_train = np.ones(np.shape(y_pred_train1))
        ones_test = np.ones(np.shape(y_pred_test1))
        # add Lap(1/eps) noise to train predicted probabilities, clip to [0, 1]
        y_pred_train1 = y_pred_train1 + np.random.laplace(loc=0., scale=ones_train / params['epsilon'], size=np.shape(y_pred_train1))
        y_pred_train1 = np.clip(y_pred_train1, 0, 1)
        # add Lap(1/eps) noise to test predicted probabilities, clip to [0, 1]
        y_pred_test1 = y_pred_test1 + np.random.laplace(loc=0., scale=ones_test / params['epsilon'], size=np.shape(y_pred_test1))
        y_pred_test1 = np.clip(y_pred_test1, 0, 1)
        y_pred_train0 = 1 - y_pred_train1
        y_pred_test0 = 1 - y_pred_test1
    f_train = logit(y_pred_train1, y_pred_train0)
    f_test = logit(y_pred_test1, y_pred_test0)

    if params['ensemble']:
        # Average (not sum) the coefficients, mirroring the 1/n_ensemble averaging
        # of the probabilities. A sum would inflate ||w|| by n_ensemble and shrink
        # every target distance relative to the single-model shadow distances.
        w_train = np.mean([clf[i].coef_[0] for i in range(params['n_ensemble'])], axis=0)
    else:
        w_train = clf.coef_[0]

    train_dist, test_dist = vanilla_distances(f_train, f_test, w_train)
    scores['dists_train'].append(train_dist)
    scores['dists_test'].append(test_dist)

    # LRT distances: target distances compared against shadow-model distances
    shadow_dists_train, shadow_dists_test = compute_shadowdistances(X_train, X_test, X_test, Y_test, params)
    # global variance comes from the shadow *distances* so it is on the same
    # scale as train_dist / test_dist (not from the shadow logits)
    var_global_dists = global_var(shadow_dists_train, shadow_dists_test)
    for use_global, key_train, key_test in (
            (True, 'dists_lrt_train_global', 'dists_lrt_test_global'),
            (False, 'dists_lrt_train_local', 'dists_lrt_test_local')):
        scores[key_train].append(compute_lambda_predictions(
            train_dist, shadow_dists_train, var_global_dists,
            global_variance=use_global))
        scores[key_test].append(compute_lambda_predictions(
            test_dist, shadow_dists_test, var_global_dists,
            global_variance=use_global))

    return scores

# COUNTERFACTUAL DISTANCE COMPUTATION
def vanilla_distances(score_train, score_test, w_train):
    '''
    Compute counterfactual distances, i.e. L2 distances to the decision boundary
    of a linear model, from its logit scores log[P(y=1|x)/(1-P(y=1|x))].

    For a linear logit f(x) = w.x + b, the smallest step that reaches the boundary
    is delta = -f(x) * w / ||w||^2, whose L2 norm is |f(x)| / ||w||. That closed
    form is evaluated directly instead of materialising an (n, d) matrix of steps.
    ---------------------------------------------------------------------------
    Args:
        score_train: logit scores of the training points (flattened to 1-D)
        score_test: logit scores of the test points (flattened to 1-D)
        w_train: weight vector of the trained linear model
    Returns: (train_dist, test_dist), 1-D arrays of counterfactual distances
    '''
    w_norm = np.linalg.norm(np.ravel(w_train), ord=2)
    train_dist = np.abs(np.ravel(score_train)) / w_norm
    test_dist = np.abs(np.ravel(score_test)) / w_norm
    return train_dist, test_dist

def _get_stable_logit_loss(label,
                           prediction,
                           eps=1e-5):
    # Signed log-odds of `prediction` w.r.t. `label`: log(p/(1-p)) when label == 1,
    # log((1-p)/p) otherwise; eps keeps both logarithms finite at p in {0, 1}.
    log_pos = np.log(prediction + eps)
    log_neg = np.log((1 - prediction) + eps)
    return log_pos - log_neg if label == 1 else log_neg - log_pos


# LIKELIHOOD RATIO / LRT ATTACK RELATED COMPUTATIONS
def compute_lambda_predictions(predictions: np.ndarray,
                               shadow_predictions: np.ndarray,
                               global_var: float,
                               global_variance=True,
                               eps=1e-5):
    '''
    Score each point by where its target-model value falls relative to the
    distribution of its shadow-model values (a one-sided Student-t test).

    For row i with k shadow values (columns of ``shadow_predictions``), let m_i be
    their sample mean and var_i the chosen variance. The score is
        Lambda_i = T_{k-1}((m_i - predictions[i]) / sqrt(var_i)),
    where T_{k-1} is the Student-t CDF with k-1 degrees of freedom. Lambda is
    small when the target value exceeds the shadow mean. Following the attack's
    convention, lower values suggest MEMBER and higher values NON-MEMBER.
    Args:
        predictions: per-point target scores (logits or CFDs) for the points the
            adversary does membership inference on
        shadow_predictions: (n, k) matrix of the same scores under k shadow models
        global_var: variance used when ``global_variance`` is True (a scalar, or
            anything broadcastable against ``predictions``)
        global_variance: (boolean) if True use ``global_var``; otherwise use each
            row's unbiased (ddof=1) sample variance plus ``eps``
        eps: small constant added to the local variance (avoids division by zero)
    Returns: estimated Lambda values in [0, 1]. With k < 2 shadow models the t
        distribution has df <= 0 and scipy yields NaN.
    '''
    n_shadow = shadow_predictions.shape[1]
    mean_all_shadow_predictions = np.mean(shadow_predictions, axis=1)

    if global_variance:
        var = global_var
    else:
        # local variance is only computed when actually needed
        var = np.var(shadow_predictions, axis=1, ddof=1) + eps

    Z = (mean_all_shadow_predictions - predictions) / np.sqrt(var)
    return t.cdf(x=Z, loc=0, scale=1, df=n_shadow - 1)


def compute_shadowpredictions(X_train, X_test, X_shadow, Y_shadow, params, eps=1e-5):
    '''
    Fit ``params['n_shadow_models']`` logistic-regression shadow models and score
    the training and test points with each of them.

    Each shadow model is fitted (not with differential privacy) on a random
    subset, drawn without replacement, of ``int(len(X_shadow) * params['frac'])``
    rows of ``X_shadow`` / ``Y_shadow``, using ``params['penalty']`` and ``params['C']``.
    Returns: (predictions_train, predictions_test), of shapes
        (n_train, n_shadow_models) and (n_test, n_shadow_models). Column i holds
        |log(p + eps) - log(1 - p + eps)|, the absolute logit of P(y=1|x) under
        shadow model i.
    '''
    n_models = params['n_shadow_models']

    def abs_logit(proba_pos):
        return np.abs(np.log(proba_pos + eps) - np.log(1 - proba_pos + eps))

    out_train = np.zeros((X_train.shape[0], n_models))
    out_test = np.zeros((X_test.shape[0], n_models))

    for col in range(n_models):
        pool_size = np.shape(X_shadow)[0]
        subset = np.random.choice(a=pool_size, size=int(pool_size * params['frac']), replace=False)

        shadow_model = LogisticRegression(penalty=params['penalty'], C=params['C'],
                                          fit_intercept=True, max_iter=2500)
        shadow_model.fit(X_shadow[subset], Y_shadow[subset])

        out_train[:, col] = abs_logit(shadow_model.predict_proba(X_train)[:, 1])
        out_test[:, col] = abs_logit(shadow_model.predict_proba(X_test)[:, 1])
    return out_train, out_test


def compute_shadowdistances(X_train, X_test, X_shadow, Y_shadow, params, epsd=1e-5):
    '''
    Compute L2 distances to the decision boundary for training and test points
    w.r.t. ``params['n_shadow_models']`` logistic-regression shadow models.

    Each shadow model is fitted on a random subset, drawn without replacement, of
    ``int(len(X_shadow) * params['frac'])`` rows of ``X_shadow`` / ``Y_shadow``.
    Distances use the closed-form L2 solution, valid because the model is linear.
    Returns: (distances_train, distances_test), of shapes
        (n_train, n_shadow_models) and (n_test, n_shadow_models); column i holds
        the distances under shadow model i.
    '''
    n_models = params['n_shadow_models']
    distances_train = np.zeros((X_train.shape[0], n_models))
    distances_test = np.zeros((X_test.shape[0], n_models))

    for i in range(n_models):
        pool_size = np.shape(X_shadow)[0]
        ind_X_prime = np.random.choice(a=pool_size, size=int(pool_size * params['frac']), replace=False)

        model = LogisticRegression(penalty=params['penalty'], C=params['C'],
                                   fit_intercept=True, max_iter=2500).fit(X_shadow[ind_X_prime],
                                                                         Y_shadow[ind_X_prime])
        # evaluate predict_proba once per data set and reuse both columns
        proba_train = model.predict_proba(X_train)
        proba_test = model.predict_proba(X_test)
        f_train = np.log(proba_train[:, 1] + epsd) - np.log(proba_train[:, 0] + epsd)  # train logit scores
        f_test = np.log(proba_test[:, 1] + epsd) - np.log(proba_test[:, 0] + epsd)  # test logit scores

        train_dist, test_dist = vanilla_distances(f_train, f_test, model.coef_[0])
        distances_train[:, i] = train_dist
        distances_test[:, i] = test_dist
    return distances_train, distances_test
